[CK_Tile] Support for a4w4 (fp4) in block scale gemm AB quant #3603
base: develop
Conversation
```cpp
using LargestInputType = largest_type_t<ADataType, BDataType>;
if constexpr(is_packed_type_v<LargestInputType>)
{
    return t<fp8_t>{};
```
I wouldn't expect this to change anytime soon, but for maintainability reasons I'd consider adding a:

```cpp
static_assert(sizeof(typename LargestInputType::type) == sizeof(fp8_t));
```
Done
```cpp
{
    const BDataType pk_val   = b_element_op(b_k_n(index));
    const fp32x2_t fp32_val  = pk_val.to_fp32x2();
    self(index) = (index[0] & 1) ? fp32_val.hi : fp32_val.lo;
```
For `a_acc` you do `(index[1] & 1)` and for `b_acc` you do `(index[0] & 1)`. The reason is not immediately apparent, and the removed hunk always did `(k & 1)`. As you've explained to me, this is because A is MxK and B is KxN.
You may want to add a comment explaining it or, even better, make the code self-explanatory by doing something like:

```cpp
constexpr auto A_TENSOR_K_DIM = 1;
constexpr auto B_TENSOR_K_DIM = 0;

(index[A_TENSOR_K_DIM] & 1)
(index[B_TENSOR_K_DIM] & 1)
```
Done
```cpp
CK_TILE_DEVICE void load_int4_tile(WarpTile& dst, const WarpWindow& src)
{
    if constexpr(std::is_same_v<SrcDataType, pk_int4_t>)
    if constexpr(numeric_traits<SrcDataType>::PackedSize > 1)
```
Use `is_packed_type_v` here
Done
```cpp
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
                                    BTypeToUse,
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
                                    typename Problem::ComputeDataType,
```
NIT: I'm wondering whether it would make more sense to rename this to `PrecomputedComputeDataType`, because "compute" is a verb and thus makes me think of a function.
Or `ComputationDataType`
I understand where you are coming from, but `ComputeDataType` is the existing convention for the MFMA input type in CK/CK Tile.
```cpp
abquant_quantgrouped_fp4_instance_factory(lut);
abquant_quantgrouped_fp8_instance_factory(lut);
abquant_quantgrouped_bf8_instance_factory(lut);
abquant_quantgrouped_preshuffleb_fp4_instance_factory(lut);
```
Can't this and the non-preshuffleb variant be in the same file/function like we do on fp8 and bf8?
I split them specifically because the preshuffleb pipeline is really slow to compile. This way it can already start compiling simultaneously with the other instances, and we don't extend compile times by a single translation unit taking longer than necessary. For other instances (e.g. the bquant instances) the preshuffleb variants are also split.
So for consistency we could actually also split the fp8/bf8 instances into a preshuffle-specific file.
```cpp
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;

// Calculate and display reference timing
using DurationType = std::chrono::duration<double>;
```
You could directly use `std::chrono::milliseconds`
That would only give millisecond precision, right? (Not that this needs precise timing, but I do look at 0.1 ms precision.)
Proposed changes
Support for packed 4-bit floating point for both A and B tensors in block scale gemm. Tested with A using 1D block scale and B using 2D block scale. Works for both the "regular" and Preshuffle-B pipelines. Note that the regular pipeline stores data as fp8 in LDS (as this is how int4 was implemented). The WP pipeline stores tensor A as fp4 in LDS and dequantizes it when loading to registers.
Changes include:

- `InterleavedPKTypeLoader` for generic type conversions instead of just int4
- `TEST_convert_with_table`

Checklist

Please put an `x` into the boxes that apply. You can also fill these out after creating the PR. If you're not sure, please don't hesitate to ask.

- [ ] `clang-format` on all changed files

Discussion

If this is a relatively large or complex change, feel free to start a discussion by explaining why you chose the solution you did and what alternatives you considered.